package org.fpp.proxy;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.URI;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

import javax.tools.DiagnosticCollector;
import javax.tools.FileObject;
import javax.tools.ForwardingJavaFileManager;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.StandardJavaFileManager;
import javax.tools.StandardLocation;
import javax.tools.ToolProvider;

/**
 * Compiles Java source code supplied as a string into class-file bytes, keeping the
 * generated class files in a caller-provided map instead of writing them to disk.
 *
 * @author bigbird-0101
 * @date 2024-06-20 22:52
 */
public class CodeCompilerUtil {
    /**
     * Compiles {@code code} with the system Java compiler and returns the bytes of the
     * class named {@code className}.
     *
     * <p>Every class file the compiler emits for this compilation unit (including nested
     * and anonymous classes) is stored in {@code inMemoryClasses}, keyed by the class
     * name the compiler reports for it.
     *
     * @param className       fully qualified name of the class declared in {@code code};
     *                        also used to derive the synthetic source file name
     * @param code            the Java source text to compile
     * @param inMemoryClasses mutable map that receives the generated class files
     * @return the bytes stored under {@code className}, or {@code null} if the map holds
     *         no entry for that name after compilation
     * @throws NullPointerException  if any argument is {@code null}
     * @throws IllegalStateException if no system Java compiler is available
     * @throws RuntimeException      if compilation fails; the message lists the diagnostics
     * @throws UncheckedIOException  if the underlying file manager cannot be closed
     */
    public static byte[] compileCodeFromString(String className, String code, Map<String, byte[]> inMemoryClasses) {
        Objects.requireNonNull(className, "className");
        Objects.requireNonNull(code, "code");
        Objects.requireNonNull(inMemoryClasses, "inMemoryClasses");

        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        if (compiler == null) {
            throw new IllegalStateException("JDK required to compile at runtime");
        }

        DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<>();
        // The standard file manager is Closeable; close it once the task is done so the
        // resources it opened are released. File-manager diagnostics go to the same
        // collector as compiler diagnostics so they show up in the failure message.
        try (StandardJavaFileManager standardManager = compiler.getStandardFileManager(diagnostics, null, null)) {
            JavaFileManager fileManager = new MemoryFileManager(inMemoryClasses, standardManager);
            JavaFileObject sourceFile = new StringJavaSource(className, code);
            JavaCompiler.CompilationTask task = compiler.getTask(
                    null, fileManager, diagnostics, null, null, Collections.singletonList(sourceFile));
            boolean success = task.call();
            if (!success) {
                throw new RuntimeException("Compilation failed: " + diagnostics.getDiagnostics());
            }
            return inMemoryClasses.get(className);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to close the compiler file manager", e);
        }
    }

    /**
     * File manager that redirects generated class files into a map and forwards every
     * other request to the wrapped file manager.
     */
    private static class MemoryFileManager extends ForwardingJavaFileManager<JavaFileManager> {
        /** Destination for generated class files, keyed by class name. */
        private final Map<String, byte[]> map;

        MemoryFileManager(Map<String, byte[]> map, JavaFileManager delegate) {
            super(delegate);
            this.map = map;
        }

        /**
         * Returns an in-memory file object for class files written to
         * {@link StandardLocation#CLASS_OUTPUT}; all other outputs are delegated.
         */
        @Override
        public JavaFileObject getJavaFileForOutput(Location location, String className,
                                                   JavaFileObject.Kind kind, FileObject sibling) throws IOException {
            if (location == StandardLocation.CLASS_OUTPUT && kind == JavaFileObject.Kind.CLASS) {
                return createInMemoryClassFile(className);
            }
            return super.getJavaFileForOutput(location, className, kind, sibling);
        }

        /**
         * Creates a class-file object whose output stream publishes its bytes to
         * {@link #map} under {@code className} when the stream is closed.
         */
        private JavaFileObject createInMemoryClassFile(String className) {
            URI uri = URI.create("memory:///" + className.replace('.', '/') + ".class");
            return new SimpleJavaFileObject(uri, JavaFileObject.Kind.CLASS) {
                @Override
                public OutputStream openOutputStream() {
                    return new ByteArrayOutputStream() {
                        @Override
                        public void close() throws IOException {
                            super.close();
                            // Only publish once the writer has finished with the stream.
                            map.put(className, toByteArray());
                        }
                    };
                }
            };
        }
    }

    /** Source file object whose content is an in-memory string. */
    static class StringJavaSource extends SimpleJavaFileObject {
        final String code;

        /**
         * @param name fully qualified class name; dots become path separators in the
         *             synthetic {@code string:///} URI
         * @param code the source text returned by {@link #getCharContent(boolean)}
         */
        StringJavaSource(String name, String code) {
            // Plain character replacement; no regex is needed to swap '.' for '/'.
            super(URI.create("string:///" + name.replace('.', '/') + Kind.SOURCE.extension), Kind.SOURCE);
            this.code = code;
        }

        @Override
        public CharSequence getCharContent(boolean ignoreEncodingErrors) {
            return code;
        }
    }
}
